import matplotlib.pyplot as plt
import numpy as np
import torch
from utils.utils_data import prepare_mnist_data, prepare_kmnist_data, prepare_fashion_data, prepare_cifar10_data, convert_to_binary_data, generate_original_binary_loaders, synth_uupr_train_dataset, synth_test_dataset, generate_uupr_loaders, get_model, generate_uupr_data,generate_small_loaders
import argparse, datetime
from utils.utils_models import linear_model, mlp_model
from utils.utils_loss import logistic_loss, sigmoid_loss, unhinged_loss, exp_loss, hinge_loss
from cifar_models import resnet
from algorithms import *

# Seed PyTorch's RNGs with a fixed value: every CUDA device first, then the default generator.
# NOTE(review): the `-seed` CLI option parsed below is not applied here; confirm whether
# generate_uupr_data (which receives `args`) is where it takes effect.
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)

# Command-line interface of the experiment runner. Every option uses a single-dash
# long name (e.g. `-lr`, `-ema_alpha`).
parser = argparse.ArgumentParser()

# Optimization settings.
parser.add_argument('-lr', help="optimizer's learning rate", default=1e-3, type=float)
parser.add_argument('-bs', help='batch_size of ordinary labels.', default=256, type=int)
parser.add_argument('-ds', help='specify a dataset', default='mnist', type=str, required=False)
parser.add_argument('-mo', help='model name', default='linear', choices=['linear', 'mlp', 'resnet'], type=str, required=False)
parser.add_argument('-ep', help='number of epochs', type=int, default=100)
parser.add_argument('-wd', help='weight decay', default=1e-3, type=float)
# Loss and learning method.
parser.add_argument('-lo', help='specify a loss function', default='logistic', type=str, choices=['sigmoid','logistic','unhinged','exp','hinge'], required=False)
parser.add_argument('-me', help='specify a method', default='PcompTeacher', type=str, choices=['PcompUnbiased','PcompReLU','BinaryBiased','RankPruning', 'PcompABS', 'NoisyUnbiased', 'PcompTeacher'], required=False)
# Data generation.
parser.add_argument('-uci', help='whether the dataset is a UCI dataset (1) or not (0)', default=1, type=int, choices=[0,1], required=False)
parser.add_argument('-seed', help='Random seed', default=0, type=int, required=False)
parser.add_argument('-n', help='number of unlabeled data pairs', default=15000, type=int, required=False)
parser.add_argument('-prior', help='class (positive) prior', default=0.5, type=float, required=False)
parser.add_argument('-gpu', help='used gpu id', default='0', type=str, required=False)
# Options intended for the PcompTeacher method, which is the only method given an
# extra EMA model.
parser.add_argument('-ema_weight', help='EMA weight used by the PcompTeacher method', default=0.01, type=float, required=False)
parser.add_argument('-ema_alpha', help='alpha coefficient of the exponential moving average used by the PcompTeacher method', default=0.97, type=float, required=False)

# Parse the command line, then pick the requested CUDA device if one is available,
# otherwise fall back to the CPU.
args = parser.parse_args()
if torch.cuda.is_available():
    device = torch.device('cuda:' + args.gpu)
else:
    device = torch.device("cpu")

# Map each `-lo` choice to its loss implementation imported from utils.utils_loss.
_LOSS_FUNCTIONS = {
    'sigmoid': sigmoid_loss,
    'logistic': logistic_loss,
    'unhinged': unhinged_loss,
    'exp': exp_loss,
    'hinge': hinge_loss,
}
try:
    loss_fn = _LOSS_FUNCTIONS[args.lo]
except KeyError:
    # argparse `choices` should make this unreachable; failing here gives a clear
    # message instead of a NameError on `loss_fn` much later.
    raise ValueError('unknown loss function: {!r}'.format(args.lo)) from None
    
# These datasets use a fixed number of unlabeled pairs, which overrides `-n`
# (even an explicitly passed value); any other dataset keeps args.n unchanged.
args.n = {
    'usps': 2000,
    'pendigits': 2500,
    'optdigits': 1000,
    'cnae-9': 200,
}.get(args.ds, args.n)
   
# Generate the data arrays for the current args; `dim` is later passed to get_model.
(
    xp, xn,
    real_yp, real_yn,
    given_yp, given_yn,
    xt, yt,
    dim,
) = generate_uupr_data(args)

# Wrap the arrays into data loaders.
(
    uupr_pos_train_loader,
    uupr_neg_train_loader,
    given_train_loader,
    real_train_loader,
    test_loader,
) = generate_uupr_loaders(xp, xn, given_yp, given_yn, real_yp, real_yn, xt, yt, args.bs)
# NOTE(review): this call rebinds `test_loader`, replacing the one returned just above,
# and `small_train_loader` (like `real_train_loader`) is not used below; confirm which
# test loader is intended.
small_train_loader, test_loader = generate_small_loaders(xp, xn, given_yp, given_yn, real_yp, real_yn, xt, yt, args.bs)

model = get_model(args, dim, device)

# Train/evaluate the method selected by `-me`; its return value is printed as
# "<method> Accuracy: <value>".
if args.me == 'BinaryBiased':
    # BinaryBiased is the only method fed the separate uupr positive/negative loaders.
    accuracy = BinaryBiased(model, uupr_pos_train_loader, uupr_neg_train_loader, test_loader, args, loss_fn, device)
elif args.me == 'PcompTeacher':
    # PcompTeacher additionally receives a second, independently built model as its EMA model.
    ema_model = get_model(args, dim, device)
    accuracy = PcompTeacher(model, ema_model, given_train_loader, test_loader, args, loss_fn, device)
else:
    # The remaining methods share one call shape:
    # (model, given_train_loader, test_loader, args, loss_fn, device).
    _SHARED_SIGNATURE_METHODS = {
        'PcompUnbiased': PcompUnbiased,
        'PcompReLU': PcompReLU,
        'RankPruning': RankPruning,
        'NoisyUnbiased': NoisyUnbiased,
        'PcompABS': PcompABS,
    }
    if args.me not in _SHARED_SIGNATURE_METHODS:
        # argparse `choices` should prevent this; fail loudly rather than silently
        # skipping training and printing only the configuration summary.
        raise ValueError('unknown method: {!r}'.format(args.me))
    accuracy = _SHARED_SIGNATURE_METHODS[args.me](model, given_train_loader, test_loader, args, loss_fn, device)

print('{} Accuracy:'.format(args.me), accuracy)

# Echo the run configuration. args.n reflects any per-dataset override applied earlier.
print(f'method:{args.me}    lr:{args.lr}    wd:{args.wd}')
print(f'loss:{args.lo}    prior:{args.prior}')
print(f'model:{args.mo}    dataset:{args.ds}')
print(f'num of sample:{args.n}')
if args.me == 'PcompTeacher':
    # The EMA settings are only reported for the method that builds an EMA model.
    print(f'alpha:{args.ema_alpha}    weight:{args.ema_weight}')